{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdfefd86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from utils import myGRU4Rec\n",
    "\n",
    "# Top-level experiment folder under 'experiments/'.\n",
    "experiment_main = 'jp:es:mx:uk:ca:us2br:in:it:de:au:fr'\n",
    "\n",
    "# Domains handed to myGRU4Rec; joined with ':' they name the innermost checkpoint folder.\n",
    "train_domain_list = ['jp', 'es', 'mx', 'uk', 'ca', 'us']\n",
    "train_prefix = ':'.join(train_domain_list)\n",
    "\n",
    "# Positional arguments of myGRU4Rec(F, f, L, train_domain_list).\n",
    "# NOTE(review): presumably input feature size, hidden size and layer count - confirm in utils.\n",
    "F, f, L = 768, 128, 2\n",
    "\n",
    "# Values that appear in the run sub-folder name 'c=.._db=.._r=.._f=..'.\n",
    "confounder, db = 2, 1\n",
    "\n",
    "# One checkpoint is loaded per value; the plots label log10 of these values as 'regularization'.\n",
    "r_list = [1.0, 10.0, 100.0, 1000.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b8fb960a",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'experiments/jp:es:mx:uk:ca:us2br:in:it:de:au:fr/c=2_db=1_r=1.0_f=128/jp:es:mx:uk:ca:us/model.pt'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_35972/3514127744.py\u001b[0m in \u001b[0;36m<cell line: 6>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmyGRU4Rec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mL\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_domain_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m     \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_folder\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'/model.pt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0mmean_norm_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_domain\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    592\u001b[0m         \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    593\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 594\u001b[0;31m     \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    595\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    596\u001b[0m             \u001b[0;31m# The zipfile reader is going to advance the current file position.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    228\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    229\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    231\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    232\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    210\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    213\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'experiments/jp:es:mx:uk:ca:us2br:in:it:de:au:fr/c=2_db=1_r=1.0_f=128/jp:es:mx:uk:ca:us/model.pt'"
     ]
    }
   ],
   "source": [
    "# Accumulators, appended to once per regularization strength in r_list.\n",
    "# mean_norm_list: sum of squared entries of the whole init_domain tensor\n",
    "# (a total, despite the 'mean' in the name).\n",
    "# domain_norm_list: domain name -> sum of squared entries of that domain's slice.\n",
    "mean_norm_list = []\n",
    "domain_norm_list = {train_domain: [] for train_domain in train_domain_list}\n",
    "\n",
    "for r in r_list:\n",
    "    # Checkpoint layout: experiments/{experiment_main}/{experiment_sub}/{train_prefix}/model.pt\n",
    "    experiment_sub = f'c={confounder}_db={db}_r={r}_f={f}'\n",
    "    save_folder = '/'.join(['experiments', experiment_main, experiment_sub, train_prefix])\n",
    "\n",
    "    model = myGRU4Rec(F, f, L, train_domain_list)\n",
    "    # map_location='cpu' keeps the load working on a CPU-only kernel even when the\n",
    "    # checkpoint was written from a GPU; .numpy() below needs CPU tensors anyway.\n",
    "    model.load_state_dict(torch.load(save_folder + '/model.pt', map_location='cpu'))\n",
    "\n",
    "    init_domain = model.init_domain.detach()\n",
    "    mean_norm_list.append(torch.sum(init_domain ** 2).numpy())\n",
    "\n",
    "    # Plain enumerate rather than np.arange: numpy is only imported in a later cell,\n",
    "    # so this cell must not depend on np when run on a fresh kernel.\n",
    "    # NOTE(review): assumes slice i of init_domain belongs to train_domain_list[i] - confirm in utils.\n",
    "    for i, domain_vec in enumerate(init_domain):\n",
    "        domain_norm_list[train_domain_list[i]].append(torch.sum(domain_vec ** 2).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cb1f4e38",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "x and y must have same first dimension, but have shapes (4,) and (0,)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_35972/1368443190.py\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog10\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mr_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog10\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnorm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmean_norm_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxlabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'np.log10(regularization)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mylabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'np.log10(mean norm)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36mplot\u001b[0;34m(scalex, scaley, data, *args, **kwargs)\u001b[0m\n\u001b[1;32m   2755\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_copy_docstring_and_deprecators\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mAxes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2756\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscalex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscaley\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2757\u001b[0;31m     return gca().plot(\n\u001b[0m\u001b[1;32m   2758\u001b[0m         \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscalex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscalex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscaley\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscaley\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2759\u001b[0m         **({\"data\": data} if data is not None else {}), **kwargs)\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/matplotlib/axes/_axes.py\u001b[0m in \u001b[0;36mplot\u001b[0;34m(self, scalex, scaley, data, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1630\u001b[0m         \"\"\"\n\u001b[1;32m   1631\u001b[0m         \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalize_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmlines\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLine2D\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m         \u001b[0mlines\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_lines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1633\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlines\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1634\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/matplotlib/axes/_base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, data, *args, **kwargs)\u001b[0m\n\u001b[1;32m    310\u001b[0m                 \u001b[0mthis\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m                 \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m             \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_plot_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mthis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    314\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_next_color\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/matplotlib/axes/_base.py\u001b[0m in \u001b[0;36m_plot_args\u001b[0;34m(self, tup, kwargs, return_kwargs)\u001b[0m\n\u001b[1;32m    496\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    497\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 498\u001b[0;31m             raise ValueError(f\"x and y must have same first dimension, but \"\n\u001b[0m\u001b[1;32m    499\u001b[0m                              f\"have shapes {x.shape} and {y.shape}\")\n\u001b[1;32m    500\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: x and y must have same first dimension, but have shapes (4,) and (0,)"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAANQklEQVR4nO3cX4il9X3H8fenuxEak0aJk5DurmRb1pi90KITI6VpTUObXXuxBLxQQ6QSWKQx5FIpNLnwprkohKBmWWSR3GQvGkk2ZRMplMSCNd1Z8N8qynSlOl3BNYYUDFRWv704p51hnHWenXNmZp3v+wUD85znNzPf+TH73mfPznlSVUiStr7f2ewBJEkbw+BLUhMGX5KaMPiS1ITBl6QmDL4kNbFq8JMcSfJakmfPcz5JvptkPsnTSa6b/piSpEkNucJ/GNj3Huf3A3vGbweB700+liRp2lYNflU9BrzxHksOAN+vkSeAy5J8YloDSpKmY/sUPscO4JUlxwvjx15dvjDJQUb/CuDSSy+9/uqrr57Cl5ekPk6ePPl6Vc2s5WOnEfys8NiK92uoqsPAYYDZ2dmam5ubwpeXpD6S/OdaP3Yav6WzAOxacrwTODOFzytJmqJpBP8YcMf4t3VuBH5TVe96OkeStLlWfUonyQ+Am4ArkiwA3wI+AFBVh4DjwM3APPBb4M71GlaStHarBr+qblvlfAFfm9pEkqR14SttJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJamJQ8JPsS/JCkvkk965w/iNJfpLkqSSnktw5/VElSZNYNfhJtgEPAPuBvcBtSfYuW/Y14Lmquha4CfiHJJdMeVZJ0gSGXOHfAMxX1emqegs4ChxYtqaADycJ8CHgDeDcVCeVJE1kSPB3AK8sOV4YP7bU/cCngTPAM8A3quqd5Z8oycEkc0nmzp49u8aRJUlrMST4WeGxWnb8ReBJ4PeBPwLuT/J77/qgqsNVNVtVszMzMxc4qiRpEkOCvwDsWnK8k9GV/FJ3Ao/UyDzwEnD1dEaUJE3DkOCfAPYk2T3+j9hbgWPL1rwMfAEgyceBTwGnpzmoJGky21dbUFXnktwNPApsA45U1akkd43PHwLuAx5O8gyjp4DuqarX13FuSdIFWjX4AFV1HDi+7LFDS94/A/zldEeTJE2Tr7SVpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDUxKPhJ9iV5Icl8knvPs+amJE8mOZXkF9MdU5I0qe2rLUiyDXgA+AtgATiR5FhVPbdkzWXAg8C+qno5ycfWaV5J0hoNucK/AZivqtNV9RZwFDiwbM3twCNV9TJAVb023TElSZMaEvwdwCtLjhfGjy11FXB5kp8nOZnkjpU+UZKDSeaSzJ09e3ZtE0uS1mRI8LPCY7XseDtwPfBXwBeBv0ty1bs+qOpwVc1W1ezMzMwFDytJWrtVn8NndEW/a8nxTuDMCmter6o3gTeTPAZcC7w4lSklSRMbcoV/AtiTZHeSS4BbgWPL1vwY+FyS7Uk+CHwWeH66o0qSJrHqFX5VnUtyN/AosA04UlWnktw1Pn+oqp5P8jPgaeAd4KGqenY9B5ckXZhULX86fmPMzs7W3NzcpnxtSXq/SnKyqmbX8rG+0laSmjD4ktSEwZek
Jgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmBgU/yb4kLySZT3Lve6z7TJK3k9wyvRElSdOwavCTbAMeAPYDe4Hbkuw9z7pvA49Oe0hJ0uSGXOHfAMxX1emqegs4ChxYYd3XgR8Cr01xPknSlAwJ/g7glSXHC+PH/l+SHcCXgEPv9YmSHEwyl2Tu7NmzFzqrJGkCQ4KfFR6rZcffAe6pqrff6xNV1eGqmq2q2ZmZmYEjSpKmYfuANQvAriXHO4Ezy9bMAkeTAFwB3JzkXFX9aBpDSpImNyT4J4A9SXYD/wXcCty+dEFV7f6/95M8DPyTsZeki8uqwa+qc0nuZvTbN9uAI1V1Ksld4/Pv+by9JOniMOQKn6o6Dhxf9tiKoa+qv558LEnStPlKW0lqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSE4OCn2RfkheSzCe5d4XzX07y9Pjt8STXTn9USdIkVg1+km3AA8B+YC9wW5K9y5a9BPxZVV0D3AccnvagkqTJDLnCvwGYr6rTVfUWcBQ4sHRBVT1eVb8eHz4B7JzumJKkSQ0J/g7glSXHC+PHzuerwE9XOpHkYJK5JHNnz54dPqUkaWJDgp8VHqsVFyafZxT8e1Y6X1WHq2q2qmZnZmaGTylJmtj2AWsWgF1LjncCZ5YvSnIN8BCwv6p+NZ3xJEnTMuQK/wSwJ8nuJJcAtwLHli5IciXwCPCVqnpx+mNKkia16hV+VZ1LcjfwKLANOFJVp5LcNT5/CPgm8FHgwSQA56pqdv3GliRdqFSt+HT8upudna25ublN+dqS9H6V5ORaL6h9pa0kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNDAp+kn1JXkgyn+TeFc4nyXfH559Oct30R5UkTWLV4CfZBjwA7Af2Arcl2bts2X5gz/jtIPC9Kc8pSZrQkCv8G4D5qjpdVW8BR4EDy9YcAL5fI08AlyX5xJRnlSRNYPuANTuAV5YcLwCfHbBmB/Dq0kVJDjL6FwDA/yR59oKm3bquAF7f7CEuEu7FIvdikXux6FNr/cAhwc8Kj9Ua1lBVh4HDAEnmqmp2wNff8tyLRe7FIvdikXuxKMncWj92yFM6C8CuJcc7gTNrWCNJ2kRDgn8C2JNkd5JLgFuBY8vWHAPuGP+2zo3Ab6rq1eWfSJK0eVZ9SqeqziW5G3gU2AYcqapTSe4anz8EHAduBuaB3wJ3Dvjah9c89dbjXixyLxa5F4vci0Vr3otUveupdknSFuQrbSWpCYMvSU2se/C9LcOiAXvx5fEePJ3k8STXbsacG2G1vViy7jNJ3k5yy0bOt5GG7EWSm5I8meRUkl9s9IwbZcCfkY8k+UmSp8Z7MeT/C993khxJ8tr5Xqu05m5W1bq9MfpP3v8A/gC4BHgK2Ltszc3ATxn9Lv+NwC/Xc6bNehu4F38MXD5+f3/nvViy7l8Y/VLALZs99yb+XFwGPAdcOT7+2GbPvYl78bfAt8fvzwBvAJds9uzrsBd/ClwHPHue
82vq5npf4XtbhkWr7kVVPV5Vvx4fPsHo9Qxb0ZCfC4CvAz8EXtvI4TbYkL24HXikql4GqKqtuh9D9qKADycJ8CFGwT+3sWOuv6p6jNH3dj5r6uZ6B/98t1y40DVbwYV+n19l9Df4VrTqXiTZAXwJOLSBc22GIT8XVwGXJ/l5kpNJ7tiw6TbWkL24H/g0oxd2PgN8o6re2ZjxLipr6uaQWytMYmq3ZdgCBn+fST7PKPh/sq4TbZ4he/Ed4J6qent0MbdlDdmL7cD1wBeA3wX+LckTVfXieg+3wYbsxReBJ4E/B/4Q+Ock/1pV/73Os11s1tTN9Q6+t2VYNOj7THIN8BCwv6p+tUGzbbQhezELHB3H/grg5iTnqupHGzLhxhn6Z+T1qnoTeDPJY8C1wFYL/pC9uBP4+xo9kT2f5CXgauDfN2bEi8aaurneT+l4W4ZFq+5FkiuBR4CvbMGrt6VW3Yuq2l1Vn6yqTwL/CPzNFow9DPsz8mPgc0m2J/kgo7vVPr/Bc26EIXvxMqN/6ZDk44zuHHl6Q6e8OKypm+t6hV/rd1uG952Be/FN4KPAg+Mr23O1Be8QOHAvWhiyF1X1fJKfAU8D7wAPVdWWu7X4wJ+L+4CHkzzD6GmNe6pqy902OckPgJuAK5IsAN8CPgCTddNbK0hSE77SVpKaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrifwHXe3WluIZOawAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# One point per regularization strength: log10(r) against log10 of the matching\n",
    "# mean_norm_list entry. Both sequences must have the same length, so every\n",
    "# checkpoint has to have loaded successfully before this cell is run.\n",
    "log_r = np.log10(r_list)\n",
    "log_norm = np.log10(mean_norm_list)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log_r, log_norm)\n",
    "ax.set_xlabel('np.log10(regularization)')\n",
    "ax.set_ylabel('np.log10(mean norm)')\n",
    "ax.set_title('effect of regularization on domain confounder vector')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbe13369",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# One curve per training domain: log10(r) against log10 of that domain's squared norm.\n",
    "fig, ax = plt.subplots()\n",
    "log_r = np.log10(r_list)  # x values are shared by every curve, so compute them once\n",
    "for key, norms in domain_norm_list.items():\n",
    "    ax.plot(log_r, np.log10(norms), label=key)\n",
    "\n",
    "# Axis labels do not depend on the curve, so they are set once, outside the loop.\n",
    "# NOTE(review): the y label says 'mean norm' but the curves are per-domain values - confirm wording.\n",
    "ax.set_xlabel('np.log10(regularization)')\n",
    "ax.set_ylabel('np.log10(mean norm)')\n",
    "ax.set_title('effect of regularization on domain confounder vector')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "2a178ec5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9545, grad_fn=<DivBackward0>)\n",
      "tensor(0.6982, grad_fn=<DivBackward0>)\n",
      "tensor(0.8696, grad_fn=<DivBackward0>)\n",
      "tensor(0.7935, grad_fn=<DivBackward0>)\n",
      "tensor(1.0477, grad_fn=<DivBackward0>)\n",
      "tensor(0.7315, grad_fn=<DivBackward0>)\n",
      "tensor(0.9111, grad_fn=<DivBackward0>)\n",
      "tensor(0.8313, grad_fn=<DivBackward0>)\n",
      "tensor(1.4323, grad_fn=<DivBackward0>)\n",
      "tensor(1.3671, grad_fn=<DivBackward0>)\n",
      "tensor(1.2456, grad_fn=<DivBackward0>)\n",
      "tensor(1.1365, grad_fn=<DivBackward0>)\n",
      "tensor(1.1499, grad_fn=<DivBackward0>)\n",
      "tensor(1.0976, grad_fn=<DivBackward0>)\n",
      "tensor(0.8029, grad_fn=<DivBackward0>)\n",
      "tensor(0.9124, grad_fn=<DivBackward0>)\n",
      "tensor(1.2603, grad_fn=<DivBackward0>)\n",
      "tensor(1.2029, grad_fn=<DivBackward0>)\n",
      "tensor(0.8799, grad_fn=<DivBackward0>)\n",
      "tensor(1.0960, grad_fn=<DivBackward0>)\n",
      "tensor(1.4204, grad_fn=<DivBackward0>)\n",
      "tensor(1.2034, grad_fn=<DivBackward0>)\n",
      "tensor(1.3373, grad_fn=<DivBackward0>)\n",
      "tensor(1.4069, grad_fn=<DivBackward0>)\n",
      "tensor(0.7040, grad_fn=<DivBackward0>)\n",
      "tensor(0.8473, grad_fn=<DivBackward0>)\n",
      "tensor(0.9415, grad_fn=<DivBackward0>)\n",
      "tensor(0.9905, grad_fn=<DivBackward0>)\n",
      "tensor(0.8310, grad_fn=<DivBackward0>)\n",
      "tensor(1.1803, grad_fn=<DivBackward0>)\n",
      "tensor(1.1112, grad_fn=<DivBackward0>)\n",
      "tensor(1.1691, grad_fn=<DivBackward0>)\n",
      "tensor(0.7478, grad_fn=<DivBackward0>)\n",
      "tensor(1.0621, grad_fn=<DivBackward0>)\n",
      "tensor(0.8999, grad_fn=<DivBackward0>)\n",
      "tensor(1.0521, grad_fn=<DivBackward0>)\n",
      "tensor(0.7108, grad_fn=<DivBackward0>)\n",
      "tensor(1.0095, grad_fn=<DivBackward0>)\n",
      "tensor(0.8554, grad_fn=<DivBackward0>)\n",
      "tensor(0.9505, grad_fn=<DivBackward0>)\n",
      "tensor(1.4881, grad_fn=<DivBackward0>)\n",
      "tensor(1.1029, grad_fn=<DivBackward0>)\n",
      "tensor(1.0666, grad_fn=<DivBackward0>)\n",
      "tensor(0.9775, grad_fn=<DivBackward0>)\n",
      "tensor(0.6720, grad_fn=<DivBackward0>)\n",
      "tensor(0.7412, grad_fn=<DivBackward0>)\n",
      "tensor(0.7167, grad_fn=<DivBackward0>)\n",
      "tensor(0.6569, grad_fn=<DivBackward0>)\n",
      "tensor(0.9067, grad_fn=<DivBackward0>)\n",
      "tensor(1.3492, grad_fn=<DivBackward0>)\n",
      "tensor(0.9670, grad_fn=<DivBackward0>)\n",
      "tensor(0.8863, grad_fn=<DivBackward0>)\n",
      "tensor(0.9376, grad_fn=<DivBackward0>)\n",
      "tensor(1.3952, grad_fn=<DivBackward0>)\n",
      "tensor(1.0341, grad_fn=<DivBackward0>)\n",
      "tensor(0.9165, grad_fn=<DivBackward0>)\n",
      "tensor(1.0230, grad_fn=<DivBackward0>)\n",
      "tensor(1.5223, grad_fn=<DivBackward0>)\n",
      "tensor(1.1283, grad_fn=<DivBackward0>)\n",
      "tensor(1.0911, grad_fn=<DivBackward0>)\n",
      "tensor(1.7236, grad_fn=<DivBackward0>)\n",
      "tensor(1.5078, grad_fn=<DivBackward0>)\n",
      "tensor(1.3352, grad_fn=<DivBackward0>)\n",
      "tensor(1.3381, grad_fn=<DivBackward0>)\n",
      "tensor(0.5802, grad_fn=<DivBackward0>)\n",
      "tensor(0.8748, grad_fn=<DivBackward0>)\n",
      "tensor(0.7746, grad_fn=<DivBackward0>)\n",
      "tensor(0.7763, grad_fn=<DivBackward0>)\n",
      "tensor(0.6632, grad_fn=<DivBackward0>)\n",
      "tensor(1.1431, grad_fn=<DivBackward0>)\n",
      "tensor(0.8855, grad_fn=<DivBackward0>)\n",
      "tensor(0.8874, grad_fn=<DivBackward0>)\n",
      "tensor(0.7490, grad_fn=<DivBackward0>)\n",
      "tensor(1.2909, grad_fn=<DivBackward0>)\n",
      "tensor(1.1293, grad_fn=<DivBackward0>)\n",
      "tensor(1.0021, grad_fn=<DivBackward0>)\n",
      "tensor(0.7474, grad_fn=<DivBackward0>)\n",
      "tensor(1.2882, grad_fn=<DivBackward0>)\n",
      "tensor(1.1269, grad_fn=<DivBackward0>)\n",
      "tensor(0.9979, grad_fn=<DivBackward0>)\n",
      "tensor(1.5378, grad_fn=<DivBackward0>)\n",
      "tensor(1.1706, grad_fn=<DivBackward0>)\n",
      "tensor(1.0720, grad_fn=<DivBackward0>)\n",
      "tensor(1.1923, grad_fn=<DivBackward0>)\n",
      "tensor(0.6503, grad_fn=<DivBackward0>)\n",
      "tensor(0.7613, grad_fn=<DivBackward0>)\n",
      "tensor(0.6971, grad_fn=<DivBackward0>)\n",
      "tensor(0.7754, grad_fn=<DivBackward0>)\n",
      "tensor(0.8542, grad_fn=<DivBackward0>)\n",
      "tensor(1.3136, grad_fn=<DivBackward0>)\n",
      "tensor(0.9157, grad_fn=<DivBackward0>)\n",
      "tensor(1.0186, grad_fn=<DivBackward0>)\n",
      "tensor(0.9329, grad_fn=<DivBackward0>)\n",
      "tensor(1.4345, grad_fn=<DivBackward0>)\n",
      "tensor(1.0920, grad_fn=<DivBackward0>)\n",
      "tensor(1.1123, grad_fn=<DivBackward0>)\n",
      "tensor(0.8387, grad_fn=<DivBackward0>)\n",
      "tensor(1.2897, grad_fn=<DivBackward0>)\n",
      "tensor(0.9818, grad_fn=<DivBackward0>)\n",
      "tensor(0.8990, grad_fn=<DivBackward0>)\n",
      "tensor(1.7731, grad_fn=<DivBackward0>)\n",
      "tensor(1.1759, grad_fn=<DivBackward0>)\n",
      "tensor(1.1774, grad_fn=<DivBackward0>)\n",
      "tensor(1.3068, grad_fn=<DivBackward0>)\n",
      "tensor(0.5640, grad_fn=<DivBackward0>)\n",
      "tensor(0.6632, grad_fn=<DivBackward0>)\n",
      "tensor(0.6640, grad_fn=<DivBackward0>)\n",
      "tensor(0.7370, grad_fn=<DivBackward0>)\n",
      "tensor(0.8504, grad_fn=<DivBackward0>)\n",
      "tensor(1.5079, grad_fn=<DivBackward0>)\n",
      "tensor(1.0013, grad_fn=<DivBackward0>)\n",
      "tensor(1.1113, grad_fn=<DivBackward0>)\n",
      "tensor(0.8493, grad_fn=<DivBackward0>)\n",
      "tensor(1.5060, grad_fn=<DivBackward0>)\n",
      "tensor(0.9987, grad_fn=<DivBackward0>)\n",
      "tensor(1.1099, grad_fn=<DivBackward0>)\n",
      "tensor(0.7652, grad_fn=<DivBackward0>)\n",
      "tensor(1.3569, grad_fn=<DivBackward0>)\n",
      "tensor(0.8998, grad_fn=<DivBackward0>)\n",
      "tensor(0.9010, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "# `model` is whatever the checkpoint-loading loop left behind, i.e. the model\n",
    "# loaded for the last value in r_list.\n",
    "vs = model.init_domain\n",
    "\n",
    "# Take the number of domain vectors from the tensor instead of hard-coding 6.\n",
    "n_domains = vs.shape[0]\n",
    "\n",
    "# For every ordered triple of distinct domains (i, j, k), print\n",
    "# squared distance(i, j) / squared distance(i, k).\n",
    "# permutations() yields the triples in the same order as three nested loops\n",
    "# with an all-distinct filter, so the printed sequence is unchanged.\n",
    "for i, j, k in permutations(range(n_domains), 3):\n",
    "    print(torch.sum((vs[i] - vs[j]) ** 2) / torch.sum((vs[i] - vs[k]) ** 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa79363",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p38",
   "language": "python",
   "name": "conda_pytorch_p38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
